use baml_types::ir_type::TypeGeneric;
use dir_writer::{FileCollector, GeneratorArgs, IntermediateRepr, LanguageFeatures};
use functions::{
    render_functions, render_functions_stream, render_runtime_code, render_source_files,
    render_type_map,
};
use generated_types::{render_go_stream_types, render_go_types};
use internal_baml_core::ir::TypeValue;

use crate::{
    functions::{render_functions_parse, render_functions_parse_stream},
    generated_types::{
        render_go_stream_types_utils, render_go_types_utils, render_type_builder_classes,
        render_type_builder_common, render_type_builder_enums,
    },
};

mod functions;
mod generated_types;
mod ir_to_go;
mod package;
mod r#type;
mod utils;

/// Stateless marker type that carries the Go implementation of
/// [`LanguageFeatures`]; all behaviour lives in that trait impl.
#[derive(Default)]
pub struct GoLanguageFeatures;

impl LanguageFeatures for GoLanguageFeatures {
    const CONTENT_PREFIX: &'static str = r#"
// ----------------------------------------------------------------------------
//
//  Welcome to Baml! To use this generated code, please run the following:
//
//  $ go get github.com/boundaryml/baml
//
// ----------------------------------------------------------------------------

// This file was generated by BAML: please do not edit it. Instead, edit the
// BAML files and re-generate this code using: baml-cli generate
// You can install baml-cli with:
//  $ go install github.com/boundaryml/baml/baml-cli
        "#;

    fn name() -> &'static str {
        "go"
    }

    fn generate_sdk_files(
        &self,
        collector: &mut FileCollector<Self>,
        ir: std::sync::Arc<IntermediateRepr>,
        args: &GeneratorArgs,
    ) -> Result<(), anyhow::Error> {
        let Some(go_mod_name) = &args.client_package_name else {
            anyhow::bail!("Go client package name is required");
        };

        let pkg = package::CurrentRenderPackage::new("baml_client", ir.clone());
        let file_map = args.file_map_as_json_string()?;
        collector.add_file("baml_source_map.go", render_source_files(file_map)?)?;
        collector.add_file("runtime.go", render_runtime_code(&pkg)?)?;
        let functions = ir
            .functions
            .iter()
            .map(|f| ir_to_go::functions::ir_function_to_go(f, &pkg))
            .collect::<Vec<_>>();
        collector.add_file(
            "functions.go",
            render_functions(&functions, &pkg, go_mod_name)?,
        )?;

        collector.add_file(
            "functions_stream.go",
            render_functions_stream(&functions, &pkg, go_mod_name)?,
        )?;

        collector.add_file(
            "functions_parse.go",
            render_functions_parse(&functions, &pkg, go_mod_name)?,
        )?;

        collector.add_file(
            "functions_parse_stream.go",
            render_functions_parse_stream(&functions, &pkg, go_mod_name)?,
        )?;

        let go_classes = ir
            .walk_classes()
            .map(|c| ir_to_go::classes::ir_class_to_go(c.item, &pkg))
            .collect::<Vec<_>>();
        let enums = ir
            .walk_enums()
            .map(|e| ir_to_go::enums::ir_enum_to_go(e.item, &pkg))
            .collect::<Vec<_>>();
        let unions = {
            let mut unions = ir
                .walk_all_non_streaming_unions()
                .filter_map(|t| ir_to_go::unions::ir_union_to_go(&t, &pkg))
                .collect::<Vec<_>>();
            // dedup by name!
            unions.sort_by_key(|u| u.name.clone());
            unions.dedup_by_key(|u| u.name.clone());
            unions
        };
        let type_aliases = ir.walk_type_aliases().collect::<Vec<_>>();

        // key-value pair of what type to drop from the cycle for any given type
        let invalid_cycles = ir
            .structural_recursive_alias_cycles()
            .iter()
            .filter(|&cycle| {
                // find all cycles considered_invalid in go
                cycle.iter().all(|(_, field_type)| {
                    // must have at least one non-recursive type
                    field_type
                        .find_if(
                            &|t| match t {
                                TypeGeneric::Class { .. } => true,
                                TypeGeneric::Enum { .. } => true,
                                TypeGeneric::Literal(..) => true,
                                TypeGeneric::Primitive(TypeValue::Null, ..) => false,
                                TypeGeneric::Primitive(..) => true,
                                _ => false,
                            },
                            true,
                        )
                        .is_empty()
                })
            })
            .flat_map(|cycle| {
                let keys = cycle.keys().cloned().collect::<Vec<_>>();
                let first_key = keys[0].clone();
                keys.into_iter().map(move |k| (k, first_key.clone()))
            })
            .collect::<baml_types::BamlMap<_, _>>();

        let mut go_type_aliases = type_aliases
            .iter()
            .map(|c| {
                ir_to_go::type_aliases::ir_type_alias_to_go(
                    c.item,
                    &pkg,
                    invalid_cycles.get(&c.elem().name),
                )
            })
            .collect::<Vec<_>>();
        go_type_aliases.sort_by(|a, b| a.name.cmp(&b.name));

        let mut stream_type_aliases = type_aliases
            .iter()
            .map(|c| {
                ir_to_go::type_aliases::ir_type_alias_to_go_stream(
                    c.item,
                    &pkg,
                    invalid_cycles.get(&c.elem().name),
                )
            })
            .collect::<Vec<_>>();
        stream_type_aliases.sort_by(|a, b| a.name.cmp(&b.name));

        let stream_unions = {
            let mut unions = ir
                .walk_all_streaming_unions()
                .filter_map(|t| ir_to_go::unions::ir_union_to_go_stream(&t, &pkg))
                .collect::<Vec<_>>();
            // dedup by name!
            unions.sort_by_key(|u| u.name.clone());
            unions.dedup_by_key(|u| u.name.clone());
            unions
        };

        let _ = collector.add_file(
            "type_map.go",
            render_type_map(
                &go_classes,
                &enums,
                &unions,
                &stream_unions,
                &go_type_aliases,
                &stream_type_aliases,
                go_mod_name,
                &pkg,
            )?,
        );

        pkg.set("baml_client.types");
        let _ = collector.add_file("types/utils.go", render_go_types_utils(&pkg)?);
        let _ = collector.add_file("types/classes.go", render_go_types(&go_classes, &pkg)?);
        let _ = collector.add_file("types/enums.go", render_go_types(&enums, &pkg)?);
        let _ = collector.add_file("types/unions.go", render_go_types(&unions, &pkg)?);
        let _ = collector.add_file(
            "types/type_aliases.go",
            render_go_types(&go_type_aliases, &pkg)?,
        );

        pkg.set("baml_client.types_builder");
        let _ = collector.add_file(
            "type_builder/type_builder.go",
            render_type_builder_common(&enums, &go_classes, &pkg)?,
        );
        let _ = collector.add_file(
            "type_builder/enums.go",
            render_type_builder_enums(&enums, &pkg)?,
        );
        let _ = collector.add_file(
            "type_builder/classes.go",
            render_type_builder_classes(&go_classes, &pkg)?,
        );

        let go_classes = ir
            .walk_classes()
            .map(|c| ir_to_go::classes::ir_class_to_go_stream(c.item, &pkg))
            .collect::<Vec<_>>();

        pkg.set("baml_client.stream_types");
        let _ = collector.add_file("stream_types/utils.go", render_go_stream_types_utils(&pkg)?);
        let _ = collector.add_file(
            "stream_types/classes.go",
            render_go_stream_types(&go_classes, &pkg, go_mod_name)?,
        );
        let _ = collector.add_file(
            "stream_types/unions.go",
            render_go_stream_types(&stream_unions, &pkg, go_mod_name)?,
        );
        let _ = collector.add_file(
            "stream_types/type_aliases.go",
            render_go_stream_types(&stream_type_aliases, &pkg, go_mod_name)?,
        );

        Ok(())
    }
}

// Wires the Go generator into the shared `test_harness` code-generation tests.
#[cfg(test)]
mod go_tests {
    use test_harness::{create_code_gen_test_suites, TestLanguageFeatures};

    // Gives the harness the name under which Go is registered for tests.
    impl TestLanguageFeatures for crate::GoLanguageFeatures {
        fn test_name() -> &'static str {
            "go"
        }
    }
    // NOTE(review): presumably expands into one test suite per shared fixture
    // for this generator — confirm against the macro definition in `test_harness`.
    create_code_gen_test_suites!(crate::GoLanguageFeatures);
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use baml_types::GeneratorOutputType;
    use dir_writer::LanguageFeatures;

    use crate::GoLanguageFeatures;

    /// The generator's advertised name must parse into the `Go` output type.
    #[test]
    fn test_name() {
        let parsed = GeneratorOutputType::from_str(GoLanguageFeatures::name())
            .expect("GoLanguageFeatures name should be a valid GeneratorOutputType");

        assert_eq!(parsed, GeneratorOutputType::Go);
    }
}
